
[Feature] Support per-draft-model MoE backend via --speculative-config#37880

Merged
hmellor merged 11 commits into vllm-project:main from
askliar:feature/different_moe_backend_for_specdec
Mar 25, 2026
Conversation

@askliar (Contributor) commented Mar 23, 2026

Summary

Adds a moe_backend field to SpeculativeConfig, allowing users to select a different MoE kernel backend for the draft/proposer model than the one used by the target (generator) model. The override is applied through a new create_vllm_config_for_spec_decode() base config builder, with create_vllm_config_for_draft_model() extending it for standalone draft models.

Motivation

In speculative decoding setups where the generator is quantized (e.g. FP8/NvFP4) but the drafter is unquantized (or vice-versa), the optimal MoE backend for each model can differ. Today vLLM applies a single --moe-backend to both, forcing users to compromise or rely on "auto" heuristics that may not be ideal for the drafter. This change lets users explicitly control the drafter's MoE backend independently.

Design

  • create_vllm_config_for_spec_decode() (new, in utils.py): base config builder that applies kernel-level overrides (currently moe_backend). Used by all proposers.
  • create_vllm_config_for_draft_model() (refactored): extends the above with quant_config, parallel_config, and model_config overrides specific to standalone draft models.
  • SpecDecodeBaseProposer._create_draft_vllm_config() (new method): calls create_vllm_config_for_spec_decode. Overridden by DraftModelProposer to call create_vllm_config_for_draft_model instead.
  • Medusa: calls create_vllm_config_for_spec_decode directly (does not extend SpecDecodeBaseProposer).

Usage

vllm serve <target_model> \
  --moe-backend flashinfer_trtllm \
  --speculative-config '{"method":"mtp","num_speculative_tokens":1,"moe_backend":"triton"}'
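The `--speculative-config` value is just a JSON string, so it can be built programmatically from a launch script. A small illustrative sketch (the compact separators reproduce the exact flag value shown above):

```python
import json

spec_config = {
    "method": "mtp",
    "num_speculative_tokens": 1,
    "moe_backend": "triton",  # drafter-only override introduced by this PR
}
# Compact separators match the quoted CLI value above.
flag_value = json.dumps(spec_config, separators=(",", ":"))
# flag_value == '{"method":"mtp","num_speculative_tokens":1,"moe_backend":"triton"}'
```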

Test plan

[x] Parametrized test_draft_moe_backend covering draft_model, eagle, and nemotron_mtp methods × 3 backend scenarios (override, inherit, default auto).
[x] test_create_vllm_config_for_spec_decode_noop_without_spec_config for the no-speculative-config edge case.
[x] Existing speculative decoding tests pass unchanged.
[x] E2E test with small eagle model on H100

@gemini-code-assist (Bot) left a comment:
Code Review

This pull request introduces a moe_backend option within SpeculativeConfig to enable distinct MoE kernel backends for draft models in speculative decoding. The implementation correctly integrates this new configuration, applying it during the creation of the draft model's configuration. The changes are logical and well-implemented. After a thorough review, I found no issues of high or critical severity.

@benchislett (Collaborator) left a comment:
Does not seem to apply to all LLM speculators, only draft models. Why?

Andrii Skliar added 3 commits March 23, 2026 17:21
…veConfig

- Introduced `moe_backend` attribute to `SpeculativeConfig` to specify the MoE backend for draft models.
- Updated `create_vllm_config_for_draft_model` to handle the new `moe_backend` setting, ensuring compatibility between drafter and generator configurations.

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
… model tests

- Added `apply_draft_moe_backend` function to override `moe_backend` in `VllmConfig` based on `speculative_config`.
- Updated `eagle.py` and `medusa.py` to utilize the new utility for model configuration.
- Introduced parameterized tests in `test_spec_decode.py` to validate the behavior of draft model configurations with various `moe_backend` scenarios.

Signed-off-by: [Andrii Skliar] <askliar@nvidia.com>
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
@askliar askliar force-pushed the feature/different_moe_backend_for_specdec branch from c476e8c to 5eb53f9 Compare March 23, 2026 16:21
… assertions

- Updated `test_spec_decode.py` to improve parameterization of draft model tests, allowing for more comprehensive validation of `moe_backend` behavior.
- Removed redundant parameters and assertions, streamlining the test logic.
- Added a new test to verify that `apply_draft_moe_backend` behaves as a no-op when no speculative configuration is provided.

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Comment thread on vllm/v1/spec_decode/utils.py (outdated):
return new_slot_mapping


def apply_draft_moe_backend(vllm_config: VllmConfig) -> VllmConfig:
Collaborator comment:

Can you consolidate this into create_vllm_config_for_spec_decode and then have create_vllm_config_for_draft_model extend it?

Probably this should also be made into a method on SpecDecodeBaseProposer that can be overridden by the draft model class. Could you make that change?

askliar (Contributor, author) replied:

Done, lmk what you think!

Andrii Skliar added 2 commits March 24, 2026 11:47
…r_spec_decode

- Updated `test_spec_decode.py` to utilize `create_vllm_config_for_spec_decode` in place of `apply_draft_moe_backend`, enhancing clarity in test assertions.
- Refactored `draft_model.py`, `eagle.py`, and `medusa.py` to adopt the new configuration utility, ensuring consistent model loading behavior.
- Introduced `create_vllm_config_for_spec_decode` to apply kernel-level overrides from speculative configurations.

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
- Removed the `create_vllm_config_for_draft_model` and `create_vllm_config_for_spec_decode` utility functions, integrating their logic directly into the relevant classes.
- Updated `DraftModelProposer`, `SpecDecodeBaseProposer`, and `MedusaProposer` to utilize the `replace` function for configuration overrides, enhancing clarity and maintainability.
- Refactored tests in `test_spec_decode.py` to align with the new configuration approach, ensuring accurate validation of `moe_backend` propagation.

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
mergify Bot commented Mar 24, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @askliar.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 24, 2026
…ent_moe_backend_for_specdec

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
@mergify mergify Bot removed the needs-rebase label Mar 24, 2026
… handling

- Updated variable names for clarity and consistency in test_spec_decode.py, enhancing readability.
- Simplified the model loading process in MedusaProposer by directly using the vllm_config without unnecessary replacements.
- Ensured that the speculative configuration is correctly propagated in the tests, maintaining the integrity of the model's behavior.

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
@benchislett (Collaborator) left a comment:
LGTM

@benchislett added the "ready" label (ONLY add when PR is ready to merge / full CI is needed) Mar 24, 2026
…dling

- Introduced a new helper function to apply draft moe_backend logic, improving test clarity and reducing redundancy.
- Added tests to verify the correct propagation and inheritance of moe_backend settings between target and draft configurations.
- Ensured that default behaviors for moe_backend are correctly validated in various scenarios.

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
@askliar askliar marked this pull request as draft March 25, 2026 07:03
- Introduced a copy of the draft_parallel_config to maintain the original configuration while updating the rank based on the current vllm_config.
- Simplified the configuration replacement logic for better clarity and maintainability.

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
@askliar askliar marked this pull request as ready for review March 25, 2026 10:18
@hmellor (Member) left a comment:

Please use vllm.config.utils.replace instead of dataclasses.replace. dataclasses.replace is not guaranteed to work on Pydantic dataclasses (which all our config classes are).
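The concern here is that `dataclasses.replace` reconstructs the object through `__init__`, which Pydantic dataclasses may customize with validation hooks. A minimal sketch of what a project-local helper with the same call shape might look like, copying the instance and setting fields directly instead of re-running `__init__` (this is an illustrative stand-in, not vLLM's actual implementation):

```python
import copy


def replace(obj, **changes):
    """Return a shallow copy of `obj` with the given fields overwritten,
    without re-invoking __init__."""
    new = copy.copy(obj)
    for field, value in changes.items():
        # object.__setattr__ also works for frozen dataclasses.
        object.__setattr__(new, field, value)
    return new
```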

- Removed unnecessary imports and streamlined the usage of the `replace` function across `test_spec_decode.py`, `draft_model.py`, and `eagle.py`.
- Enhanced clarity in the configuration handling within `DraftModelProposer` by directly utilizing the `replace` function for updating the draft parallel configuration.

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
askliar commented Mar 25, 2026

@hmellor done, please, take a look!

@hmellor (Member) left a comment:

Thanks for making the change!

@hmellor hmellor enabled auto-merge (squash) March 25, 2026 13:44
@hmellor hmellor merged commit cd76430 into vllm-project:main Mar 25, 2026
60 checks passed
RhizoNymph pushed a commit to RhizoNymph/vllm that referenced this pull request Mar 26, 2026
…ig` (vllm-project#37880)

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Signed-off-by: [Andrii Skliar] <askliar@nvidia.com>
Co-authored-by: Andrii Skliar <askliar@nvidia.com>
nithinvc pushed a commit to nithinvc/vllm that referenced this pull request Mar 27, 2026
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
puririshi98 pushed a commit to puririshi98/vllm that referenced this pull request Apr 7, 2026
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Apr 8, 2026
### What this PR does / why we need it?
Main2main upgrade vllm to 0330
fix breaks:
1. vllm-project/vllm#37728 add clear_row method
for BlockTable
2. vllm-project/vllm#37975 Adapt
GatedDeltaNetAttention Refactor
3. vllm-project/vllm#37698 update
maybe_update_config in vllm_ascend/quantization/modelslim_config.py to
adapt this pr change
4. vllm-project/vllm#37880 This pr add the feat
where we can set different moe backends between draft and target model,
we should overwrite it in the draft proposer
5. vllm-project/vllm#37853 for now just to skip
test_cpu_offloading.py test case utils this feature has been adapted.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?

CI

- vLLM version: v0.18.0
- vLLM main:
vllm-project/vllm@29e4870

---------

Signed-off-by: 22dimensions <waitingwind@foxmail.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wangli <wangli858794774@gmail.com>
Co-authored-by: Claude Code <claude@anthropic.com>
Co-authored-by: Claude Code <noreply@anthropic.com>
Co-authored-by: wxsIcey <1790571317@qq.com>
Co-authored-by: wangli <wangli858794774@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026

Labels

ready (ONLY add when PR is ready to merge / full CI is needed), speculative-decoding, v1

Projects

None yet

3 participants